{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Volumetric data processing\n",
    "This is a simple demo on toy 3d data for source extraction and deconvolution using CaImAn.\n",
    "For more information check demo_pipeline.ipynb which performs the complete pipeline for\n",
    "2d two photon imaging data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython import get_ipython\n",
    "import logging \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from scipy.ndimage import gaussian_filter\n",
    "from tifffile.tifffile import imwrite\n",
    "\n",
    "import caiman as cm\n",
    "from caiman.utils.visualization import nb_view_patches3d\n",
    "import caiman.source_extraction.cnmf as cnmf\n",
    "from caiman.paths import caiman_datadir\n",
    "\n",
    "try:\n",
    "    if __IPYTHON__:\n",
    "        get_ipython().run_line_magic('load_ext', 'autoreload')\n",
    "        get_ipython().run_line_magic('autoreload', '2')\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "import bokeh.plotting as bpl\n",
    "bpl.output_notebook()\n",
    "\n",
    "logfile = None # Replace with a path if you want to log to a file\n",
    "logger = logging.getLogger('caiman')\n",
    "# Set to logging.INFO if you want much output, potentially much more output\n",
    "logger.setLevel(logging.WARNING)\n",
    "logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')\n",
    "if logfile is not None:\n",
    "    handler = logging.FileHandler(logfile)\n",
    "else:\n",
    "    handler = logging.StreamHandler()\n",
    "handler.setFormatter(logfmt)\n",
    "logger.addHandler(handler)\n",
    "# Set play_movies to False if you want to disable play of movies, e.g. for remote-hosted Jupyter environments\n",
    "play_movies = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function to create some toy data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(p=1, noise=.5, T=256, framerate=30, firerate=2., motion=True, plot=False):\n",
    "    if p == 2:\n",
    "        gamma = np.array([1.5, -.55])\n",
    "    elif p == 1:\n",
    "        gamma = np.array([.9])\n",
    "    else:\n",
    "        raise\n",
    "    dims = (70, 50, 10)  # size of image\n",
    "    sig = (4, 4, 2)  # neurons size\n",
    "    bkgrd = 10\n",
    "    N = 20  # number of neurons\n",
    "    np.random.seed(0)\n",
    "    centers = np.asarray([[np.random.randint(s, x - s)\n",
    "                           for x, s in zip(dims, sig)] for i in range(N)])\n",
    "    if motion:\n",
    "        centers += np.array(sig) * 2\n",
    "        Y = np.zeros((T,) + tuple(np.array(dims) + np.array(sig) * 4), dtype=np.float32)      \n",
    "    else:\n",
    "        Y = np.zeros((T,) + dims, dtype=np.float32)\n",
    "    trueSpikes = np.random.rand(N, T) < firerate / float(framerate)\n",
    "    trueSpikes[:, 0] = 0\n",
    "    truth = trueSpikes.astype(np.float32)\n",
    "    for i in range(2, T):\n",
    "        if p == 2:\n",
    "            truth[:, i] += gamma[0] * truth[:, i - 1] + gamma[1] * truth[:, i - 2]\n",
    "        else:\n",
    "            truth[:, i] += gamma[0] * truth[:, i - 1]\n",
    "    for i in range(N):\n",
    "        Y[:, centers[i, 0], centers[i, 1], centers[i, 2]] = truth[i]\n",
    "    tmp = np.zeros(dims)\n",
    "    tmp[tuple(np.array(dims)//2)] = 1.\n",
    "    z = np.linalg.norm(gaussian_filter(tmp, sig).ravel())\n",
    "    Y = bkgrd + noise * np.random.randn(*Y.shape) + 10 * gaussian_filter(Y, (0,) + sig) / z\n",
    "    if motion:\n",
    "        shifts = np.transpose([np.convolve(np.random.randn(T-10), np.ones(11)/11*s) for s in sig])\n",
    "        Y = np.array([cm.motion_correction.apply_shifts_dft(img, (sh[0], sh[1], sh[2]), 0,\n",
    "                                                                     is_freq=False, border_nan='copy')\n",
    "                               for img, sh in zip(Y, shifts)])\n",
    "        Y = Y[:, 2*sig[0]:-2*sig[0], 2*sig[1]:-2*sig[1], 2*sig[2]:-2*sig[2]]\n",
    "    else:\n",
    "        shifts = None\n",
    "    T, d1, d2, d3 = Y.shape\n",
    "\n",
    "    if plot:\n",
    "        Cn = cm.local_correlations(Y, swap_dim=False)\n",
    "        plt.figure(figsize=(15, 3))\n",
    "        plt.plot(truth.T)\n",
    "        plt.figure(figsize=(15, 3))\n",
    "        for c in centers:\n",
    "            plt.plot(Y[c[0], c[1], c[2]])\n",
    "\n",
    "        d1, d2, d3 = dims\n",
    "        x, y = (int(1.2 * (d1 + d3)), int(1.2 * (d2 + d3)))\n",
    "        scale = 6/x\n",
    "        fig = plt.figure(figsize=(scale*x, scale*y))\n",
    "        axz = fig.add_axes([1-d1/x, 1-d2/y, d1/x, d2/y])\n",
    "        plt.imshow(Cn.max(2).T, cmap='gray')\n",
    "        plt.scatter(*centers.T[:2], c='r')\n",
    "        plt.title('Max.proj. z')\n",
    "        plt.xlabel('x')\n",
    "        plt.ylabel('y')\n",
    "        axy = fig.add_axes([0, 1-d2/y, d3/x, d2/y])\n",
    "        plt.imshow(Cn.max(0), cmap='gray')\n",
    "        plt.scatter(*centers.T[:0:-1], c='r')\n",
    "        plt.title('Max.proj. x')\n",
    "        plt.xlabel('z')\n",
    "        plt.ylabel('y')\n",
    "        axx = fig.add_axes([1-d1/x, 0, d1/x, d3/y])\n",
    "        plt.imshow(Cn.max(1).T, cmap='gray')\n",
    "        plt.scatter(*centers.T[np.array([0,2])], c='r')\n",
    "        plt.title('Max.proj. y')\n",
    "        plt.xlabel('x')\n",
    "        plt.ylabel('z');\n",
    "        plt.show()\n",
    "\n",
    "    return Y, truth, trueSpikes, centers, dims, -shifts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select file(s) to be processed\n",
    "- create a file with a toy 3d dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fname = os.path.join(caiman_datadir(), 'example_movies', 'demoMovie3D.tif')\n",
    "Y, truth, trueSpikes, centers, dims, shifts = gen_data(p=2)\n",
    "imwrite(fname, Y)\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display the raw movie (optional)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show a max-projection of the correlation image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = cm.load(fname)\n",
    "Cn = cm.local_correlations(Y, swap_dim=False)\n",
    "d1, d2, d3 = dims\n",
    "x, y = (int(1.2 * (d1 + d3)), int(1.2 * (d2 + d3)))\n",
    "scale = 6/x\n",
    "fig = plt.figure(figsize=(scale*x, scale*y))\n",
    "axz = fig.add_axes([1-d1/x, 1-d2/y, d1/x, d2/y])\n",
    "plt.imshow(Cn.max(2).T, cmap='gray')\n",
    "plt.title('Max.proj. z')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y')\n",
    "axy = fig.add_axes([0, 1-d2/y, d3/x, d2/y])\n",
    "plt.imshow(Cn.max(0), cmap='gray')\n",
    "plt.title('Max.proj. x')\n",
    "plt.xlabel('z')\n",
    "plt.ylabel('y')\n",
    "axx = fig.add_axes([1-d1/x, 0, d1/x, d3/y])\n",
    "plt.imshow(Cn.max(1).T, cmap='gray')\n",
    "plt.title('Max.proj. y')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('z');\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Play the movie (optional). This will require loading the movie in memory which in general is not needed by the pipeline. Displaying the movie uses the OpenCV library. Press `q` to close the video panel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if play_movies:\n",
    "    Y[...,5].play(magnification=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup a cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)\n",
    "if 'dview' in locals():\n",
    "    cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Motion Correction\n",
    "First we create a motion correction object with the parameters specified. Note that the file is not loaded in memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# motion correction parameters\n",
    "opts_dict = {'fnames': fname,\n",
    "            'strides': (24, 24, 6),    # start a new patch for pw-rigid motion correction every x pixels\n",
    "            'overlaps': (12, 12, 2),   # overlap between patches (size of patch strides+overlaps)\n",
    "            'max_shifts': (4, 4, 2),   # maximum allowed rigid shifts (in pixels)\n",
    "            'max_deviation_rigid': 5,  # maximum shifts deviation allowed for patch with respect to rigid shifts\n",
    "            'pw_rigid': False,         # flag for performing non-rigid motion correction\n",
    "            'is3D': True}\n",
    "\n",
    "opts = cnmf.params.CNMFParams(params_dict=opts_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first we create a motion correction object with the parameters specified\n",
    "mc = cm.motion_correction.MotionCorrect(fname, dview=dview, **opts.get_group('motion'))\n",
    "# note that the file is not loaded in memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "#%% Run motion correction using NoRMCorre\n",
    "mc.motion_correct(save_movie=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12,3))\n",
    "for i, s in enumerate((mc.shifts_rig, shifts)):\n",
    "    plt.subplot(1,2,i+1)\n",
    "    for k in (0,1,2):\n",
    "        plt.plot(np.array(s)[:,k], label=('x','y','z')[k])\n",
    "    plt.legend()\n",
    "    plt.title(('inferred shifts', 'true shifts')[i])\n",
    "    plt.xlabel('frames')\n",
    "    plt.ylabel('pixels')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Memory mapping \n",
    "\n",
    "The cell below memory maps the file in order `'C'` and then loads the new memory mapped file. The saved files from motion correction are memory mapped files stored in `'F'` order. Their paths are stored in `mc.mmap_file`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% MEMORY MAPPING\n",
    "# memory map the file in order 'C'\n",
    "fname_new = cm.save_memmap(mc.mmap_file, base_name='memmap_', order='C',\n",
    "                           border_to_0=0, dview=dview) # exclude borders\n",
    "\n",
    "# now load the file\n",
    "Yr, dims, T = cm.load_memmap(fname_new)\n",
    "images = np.reshape(Yr.T, [T] + list(dims), order='F') \n",
    "    #load frames in python format (T x X x Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now restart the cluster to clean up memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% restart cluster to clean up memory\n",
    "cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## If data is small enough use a single patch approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set parameters\n",
    "K = 20  # number of neurons expected per patch\n",
    "gSig = [4, 4, 2]  # expected half size of neurons\n",
    "merge_thresh = 0.8  # merging threshold, max correlation allowed\n",
    "p = 2  # order of the autoregressive system"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run CNMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# INIT\n",
    "cnm = cnmf.CNMF(n_processes, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, dview=dview)\n",
    "cnm.params.set('spatial', {'se': np.ones((3,3,1), dtype=np.uint8)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# FIT\n",
    "cnm.fit(images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View the results\n",
    "View components per plane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims, axis=2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For larger data use a patch approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set parameters\n",
    "rf = 15  # half-size of the patches in pixels. rf=25, patches are 50x50\n",
    "stride = 10  # amount of overlap between the patches in pixels\n",
    "K = 10  # number of neurons expected per patch\n",
    "gSig = [4, 4, 2]  # expected half size of neurons\n",
    "merge_thresh = 0.8  # merging threshold, max correlation allowed\n",
    "p = 2  # order of the autoregressive system\n",
    "print('set')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run CNMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "cnm = cnmf.CNMF(n_processes, \n",
    "                k=K, \n",
    "                gSig=gSig, \n",
    "                merge_thresh=merge_thresh, \n",
    "                p=p, \n",
    "                dview=dview,\n",
    "                rf=rf, \n",
    "                stride=stride, \n",
    "                only_init_patch=True)\n",
    "cnm.params.set('spatial', {'se': np.ones((3,3,1), dtype=np.uint8)})\n",
    "cnm.fit(images)\n",
    "print(('Number of components:' + str(cnm.estimates.A.shape[-1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Component Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% COMPONENT EVALUATION\n",
    "# the components are evaluated in two ways:\n",
    "#   a) the shape of each component must be correlated with the data\n",
    "#   b) a minimum peak SNR is required over the length of a transient\n",
    "\n",
    "fr = 10 # approx final rate  (after eventual downsampling )\n",
    "decay_time = 1.  # length of typical transient in seconds \n",
    "use_cnn = False  # CNN classifier is designed for 2d (real) data\n",
    "min_SNR = 3      # accept components with that peak-SNR or higher\n",
    "rval_thr = 0.6   # accept components with space correlation threshold or higher\n",
    "cnm.params.change_params(params_dict={'fr': fr,\n",
    "                                      'decay_time': decay_time,\n",
    "                                      'min_SNR': min_SNR,\n",
    "                                      'rval_thr': rval_thr,\n",
    "                                      'use_cnn': use_cnn})\n",
    "\n",
    "cnm.estimates.evaluate_components(images, cnm.params, dview=dview)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(('Keeping ' + str(len(cnm.estimates.idx_components)) +\n",
    "       ' and discarding  ' + str(len(cnm.estimates.idx_components_bad))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Re-run seeded CNMF\n",
    "Now we re-run CNMF on the whole FOV seeded with the accepted components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "cnm.params.set('temporal', {'p': p})\n",
    "cnm2 = cnm.refit(images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View the results\n",
    "Unlike the above layered view, here we view the components as max-projections (frontal in the XY direction, sagittal in YZ direction and transverse in XZ), and we also show the denoised trace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm2.estimates.nb_view_components_3d(image_type='corr', \n",
    "                                     dims=dims, \n",
    "                                     Yr=Yr, \n",
    "                                     denoised_color='red', \n",
    "                                     max_projection=True,\n",
    "                                     axis=2);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# STOP CLUSTER\n",
    "cm.stop_server(dview=dview)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "caiman_pytorch2",
   "language": "python",
   "name": "caiman_pytorch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
